/**********************************************************************/
/*
   Author: Karan Makkar, adapted from Michelle's code
   Created: Nov 2023
   Description: Define regression programs for lee bounds regressions.

   Note: to run, run the corresponding master.do file
*/
/**********************************************************************/

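* Define OLS Lee bounds regression
* Inputs:
* name: name under which the regression is stored (eststo)
* y: outcome variable
* x: treatment indicator
* c: control vars
* absorb: absorb() option, passed verbatim after the comma
* vce: vce() option or clustering levels, passed verbatim after absorb
* precomma: if conditions for cuts
* varcount: suffix of the global bounds_<varcount> that receives the CI string
* Arguments are passed in the order they appear in the actual regression:
* name y x c precomma absorb vce varcount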

  * lee_ols_reg: runs the main OLS (reghdfe) regression, stores it under
  * the given name, and attaches tightened Lee-bounds results to it.
  *
  * Positional arguments:
  *   name      name used by eststo / est restore
  *   y         outcome variable
  *   x         treatment indicator (entered as i.x in the main regression)
  *   c         control variables
  *   precomma  if-condition of the main regression
  *   absorb    option text placed right after the comma
  *   vce       option text placed after absorb
  *   varcount  suffix of the global bounds_varcount written at the end
  *
  * Hidden inputs: the bounds part does not use x; it is hard-coded to the
  * variable win_in_batch and also reads age, ed_cat, gender,
  * test_score_bins, stratum_win_prob, batch and the globals demog_batch
  * and vce.
  *
  * Results added to the stored estimates: control_mean, obs, lb_tight,
  * ub_tight, tight_lb, tight_ub, tight_lb_se, tight_ub_se, n_sq,
  * lb_tight_ci, ub_tight_ci.
  *
  * Side effects: selection, pred_outcome and pred_outcome_q are recreated
  * and left in the data, the data are re-sorted, and global scalars and
  * matrices (lb, ub, vlb, vub, weight, obs, ctst, ...) are overwritten.
  * Ties are broken with runiform(), so results are reproducible only if
  * the caller sets the seed.
  *
  * NOTE(review): the scalars lb, ub, weight, ... are referenced by bare
  * name; in Stata expressions a variable of the same name takes
  * precedence over a scalar. Confirm the data hold no such variables.
  cap program drop lee_ols_reg
  program define lee_ols_reg
    version 17.0
    args name y x c precomma absorb vce varcount

    di "Running: OLS Regression for `y' - OLS"
    di "`c(current_time)'"

    * Run regression
    eststo `name' : reghdfe `y' i.`x' `c' if `precomma', `absorb' `vce'

    * Get control complier mean:
    * mean of y among x == 1 minus the estimated coefficient on 1.x
    local beta_iv = _b[1.`x']
    qui sum `y' if `precomma' & `x' == 1
    qui estadd scalar control_mean = round(`r(mean)' - `beta_iv', 0.001)

    * selection = 1 for observations used in the main regression
    cap drop selection
    gen selection = e(sample)

    * Store the estimation sample size as a 1x1 matrix
    matrix define obs = e(N)
    mat colnames obs = win_in_batch
    matrix list obs
    estadd matrix obs

    ** tightened leebounds
    * predicted outcomes: fit y on covariates within the selected sample,
    * then cut the prediction into quantile cells
    local max_quantile = 10
    cap drop pred_outcome*
    reghdfe `y' age i.ed_cat i.gender i.test_score_bins stratum_win_prob c.stratum_win_prob#i.batch if selection==1, $demog_batch $vce
    qui predict pred_outcome
    qui xtile pred_outcome_q =pred_outcome, nquantiles(`max_quantile')

    ** tightened leebounds
    * leebounds is run only to read e(cellsel); the bounds themselves are
    * computed by hand further down
    leebounds `y' win_in_batch, cieffect tight(pred_outcome_q) select(selection)

    * While leebounds reports cellsel == hetero, reduce the number of
    * quantile cells one at a time; fall back to a single cell if the
    * count reaches 2
    di "`e(cellsel)'"
    if "`e(cellsel)'" == "hetero" {
        local n = `max_quantile'
        while "`e(cellsel)'" == "hetero" & `n'>2 {
        local n = `n' - 1
        di "n : `n' "
        drop pred_outcome_q
        qui xtile pred_outcome_q =pred_outcome, nquantiles(`n')
        leebounds `y' win_in_batch, cieffect tight(pred_outcome_q) select(selection)
        di "`e(cellsel)'"
        }
        // cannot find the quantile
        if `n' <= 2 {
            drop pred_outcome_q
            gen pred_outcome_q = 1
        }
    }


    * Selection (response) rates by treatment arm
    sum selection if win_in_batch == 0
    local select_0 = `r(mean)'
    sum selection if win_in_batch == 1
    local select_1 = `r(mean)'

    * reverse if the control group has been surveyed further (lower attrition)
    * From here until the end of the program win_in_batch holds the flipped
    * indicator; the original is kept in win_in_batch_prev
    if `select_0' > `select_1' {
        clonevar win_in_batch_rev = win_in_batch
        replace win_in_batch_rev = 1 if win_in_batch==0
        replace win_in_batch_rev = 0 if win_in_batch==1
        rename win_in_batch win_in_batch_prev
        rename win_in_batch_rev win_in_batch
        }

    *Correct for "ecological fallacy"/weighting problem
    * If quantiles don't "agree" with aggregate, use untightened bounds
    * loop is set to 0 once one disagreeing cell has been found, after which
    * the remaining levels are skipped
    local loop = 1
    levelsof pred_outcome_q
    foreach level in `r(levels)' {
        if `loop' == 0 continue
        summ selection if win_in_batch == 1 & pred_outcome_q == `level'
        local treat_select = `r(mean)'
        summ selection if win_in_batch == 0 & pred_outcome_q == `level'

        if `r(mean)' <= `treat_select' continue
        else {
            *drop pred_outcome_q
            replace pred_outcome_q = 1
            local loop = 0
            di "`loop'"
        }
    }
    * # of always takers
    qui count if selection == 1 & win_in_batch == 0
    local always_takers = `r(N)'

    * calculate bounds within each cell
    * Accumulators for the weighted sums over cells
    scalar ub = 0
    scalar lb = 0
    scalar vlb = 0
    scalar vub = 0
    scalar weight = 0
    scalar ubttreat_sq = 0
    scalar lbttreat_sq = 0

    * num counts processed cells; ncat is the number of distinct cells
    local num = 0
    tab pred_outcome_q
    local ncat = r(r)

    levelsof pred_outcome_q

    foreach level in `r(levels)' {
        local num = `num' + 1

        * create trimmed vars
        gen outcome_t = `y' if selection==1 & win_in_batch == 1 & pred_outcome_q == `level'
        gen outcome_c = `y' if selection==1 & win_in_batch == 0 & pred_outcome_q == `level'

        * randomize ties
        * Sort on the full name rand_tie: the abbreviation rand fails with
        * varabbrev off and could silently match another variable
        gen rand_tie = runiform() if pred_outcome_q == `level'
        sort outcome_t rand_tie, stable
        * Missing outcome_t sorts last, so order is the rank among treated
        * selected observations of this cell and quantile is rank / count
        gen order = _n if outcome_t != .
        summ order
        gen quantile = order / `r(N)'

        * calculate trimming thresholds
        * q = 1 - (control selection rate / treated selection rate) in the cell
        summ selection if win_in_batch == 1 & pred_outcome_q == `level'
        local treat_select = `r(mean)'
        summ selection if win_in_batch == 0 & pred_outcome_q == `level'
        local q`level' = 1 - (`r(mean)' / `treat_select')

        * trimmed samples
        * Not reversed: upper bound keeps the top of the treated
        * distribution, lower bound keeps the bottom

        if `select_1' >= `select_0'{
            gen `y'_ub = outcome_t if quantile >= `q`level''
            replace `y'_ub = outcome_c if win_in_batch == 0

            gen `y'_lb = outcome_t if quantile <= (1 - `q`level'')
            replace `y'_lb = outcome_c if win_in_batch == 0
        }

        * Reversed arms: the two trimming directions are swapped
        if `select_0' > `select_1'{
            gen `y'_ub = outcome_t if quantile <= (1 - `q`level'')
            replace `y'_ub = outcome_c if win_in_batch == 0

            gen `y'_lb = outcome_t if quantile >= `q`level''
            replace `y'_lb = outcome_c if win_in_batch == 0
        }

        * Cell shares of selection x treatment within this quantile cell
        tab selection win_in_batch if pred_outcome_q == `level' , matcell(ctst)

        local nall = r(N)
        mat ctst = ctst/`nall'
        local est   = ctst[2,2]
        local esnt  = ctst[2,1]
        local et    = ctst[1,2]+ctst[2,2]
        local oddsc = ctst[1,1]/ctst[2,1]
        local oddst = ctst[1,2]/ctst[2,2]

        **** asymptotic variance
        *** lower bound
        local itrim = 100- 100*`q`level''
        _pctile `y' if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level', percentiles(`itrim')
        local lth = r(r1)

        sum `y'_lb if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level'
        local lub = r(mean)

        ** Analytic Variance
        * r(Var) below still refers to the summarize of the trimmed outcome
        local vp = (1-`q`level'')^2*(`oddst'/(`et')+`oddsc'/(1-`et'))
        di "`vp'"
        local vb1 = r(Var)/(`est'*(1-`q`level''))
        di "`vb1'"
        local vb2 = (`lth'-`lub')^2*(`q`level'')*(`est'*(1-`q`level''))^-1
        di "`vb2'"
        local vb3 = ((`lth'-`lub')/(1-`q`level''))^2*`vp'
        di "`vb3'"

        scalar vlb`level' = `vb1'+`vb2'+`vb3'

        di "vlb:"
        di vlb`level'

        *** upper bound
        local itrim = 100*`q`level''
        _pctile `y' if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level', percentiles(`itrim')
        local uth = r(r1)

        sum `y'_ub if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level'
        local tub = r(mean)

        ** Analytic Variance
        local vp = (1-`q`level'')^2*(`oddst'/`et'+`oddsc'/(1-`et'))
        local vb1 = r(Var)/(`est'*(1-`q`level''))
        local vb2 = (`uth'-`tub')^2*(`q`level'')*(`est'*(1-`q`level''))^-1
        local vb3 = ((`uth'-`tub')/(1-`q`level''))^2*`vp'

        scalar vub`level' = `vb1'+`vb2'+`vb3'
        * Show the scalar's value (previously only its name was printed)
        di "VUB: " vub`level'

        * weight
        * share of all selected control observations that fall in this cell
        count if selection == 1 & win_in_batch == 0 & pred_outcome_q == `level'
        scalar weight`level' = `r(N)' / `always_takers'

        * regressions
        * The trimmed outcomes are non-missing only inside this cell, so
        * each regression is effectively restricted to the cell
        reghdfe `y'_lb win_in_batch stratum_win_prob c.stratum_win_prob#i.batch if selection==1 , $demog_batch  $vce
        scalar lb`level' = _b[win_in_batch]

        * Undo the sign change caused by the flipped treatment indicator
        if `select_0' > `select_1' {
            scalar lb`level' = - lb`level'
        }

        scalar lb = lb + (lb`level' * weight`level')

        reghdfe `y'_ub win_in_batch stratum_win_prob c.stratum_win_prob#i.batch if selection==1 , $demog_batch  $vce
        scalar ub`level' = _b[win_in_batch]

        if `select_0' > `select_1' {
            scalar ub`level' = - ub`level'
        }

        scalar ub = ub + (ub`level' * weight`level')

        di "weight"
        di weight`level'

        scalar weight = weight + weight`level'

        scalar vlb = vlb + (weight`level')*vlb`level'
        scalar vub = vub + (weight`level')*vub`level'

        di "treat_sq"
        scalar ubttreat_sq = ubttreat_sq + weight`level'*(ub`level'-`tub')^2
        scalar lbttreat_sq = lbttreat_sq + weight`level'*(lb`level'-`lub')^2

        di "vlb"
        di vlb
        di "ubttreat_sq"
        di ubttreat_sq

        * for the last one Chamberlain 1994: The asymptotic variance is the
        * sum of two components: 1) the (weighted) average of the asymptotic
        * variance for each group (Λ1 in Chamberlain (1994), 2) the
        * (weighted) average squared deviation of each group's estimate from
        * the "Total" mean (Λ2 in Chamberlain (1994)). "Total" refers to the
        * square root of the sum the squared components.
        * https://www.princeton.edu/~davidlee/wp/resrevision8.pdf

        if `num' == `ncat' {
            di "last ncat"
            di ubttreat_sq
            di lbttreat_sq
            di vlb
            di weight

            sum `y' if win_in_batch == 0 & selection == 1
            local vc = r(Var)/r(sum_w)

            scalar Vub_tot_sq = sqrt(vub + ubttreat_sq + `vc')
            scalar Vlb_tot_sq = sqrt(vlb + lbttreat_sq + `vc')

// 				scalar vub = (vub/(weight)^2+((ubttreat_sq/weight)-(ub/weight)^2))/`ntall'+`vc'
// 				scalar vlb = (vlb/(weight)^2+((lbttreat_sq/weight)-(lb/weight)^2))/`ntall'+`vc'
// 				scalar vlb_sq = sqrt(vlb)
// 				scalar vub_sq = sqrt(vub)
        }

        * drop
        drop outcome_* order rand_tie quantile `y'_lb `y'_ub

    } // end of loop through quantiles

    * Square root of the number of tabulated observations, used to scale
    * the standard errors in the confidence interval below
    tab selection win_in_batch, matcell(ctst)
    local nall = r(N)
    scalar nall_sq = sqrt(`nall')

    di lb
    di ub

    matrix define lb_tight = lb
    matrix define ub_tight = ub

    mat colnames lb_tight = win_in_batch
    mat colnames ub_tight = win_in_batch
    * Make the stored main regression active again before adding results
    est restore `name'
    estadd matrix lb_tight
    estadd matrix ub_tight

    estadd scalar tight_lb = lb
    estadd scalar tight_ub = ub

    estadd scalar tight_lb_se = Vlb_tot_sq
    estadd scalar tight_ub_se = Vub_tot_sq

    estadd scalar n_sq = nall_sq

    * CI end points: bound -/+ 1.96 * se / sqrt(N), stored rounded to 2 decimals
    local lb_tight_ci = `e(tight_lb)' - 1.96*`e(tight_lb_se)'/`e(n_sq)'
    local lb_tight_ci: di %4.2f `lb_tight_ci'
    estadd scalar lb_tight_ci `lb_tight_ci'

    local ub_tight_ci = `e(tight_ub)' + 1.96*`e(tight_ub_se)'/`e(n_sq)'
    local ub_tight_ci: di %4.2f `ub_tight_ci'
    estadd scalar ub_tight_ci `ub_tight_ci'

    global bounds_`varcount' = "(`lb_tight_ci', `ub_tight_ci')"
    di "varcount: `varcount'"
    di "bounds: ${bounds_`varcount'}"

    * reverse DV switch for higher control group response rate
    if `select_0' > `select_1' {
        replace win_in_batch = win_in_batch_prev
        drop win_in_batch_prev
        }

  end


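* Define IV regression
* Inputs:
* name: name under which the regression is stored (eststo)
* y: outcome variable
* x: treatment indicator (endogenous regressor)
* c: controls
* instrument: instrument for x
* absorb: absorb() option, passed verbatim after the comma
* cluster: vce() or cluster() option, passed verbatim after absorb
* precomma: if conditions for cuts
* varcount: suffix of the global bounds_<varcount> that receives the CI string
* Arguments are passed in the order they appear in the actual regression:
* name y x instrument c precomma absorb cluster varcount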

  * lee_iv_reg: runs the IV (ivreghdfe) regression, stores it under the
  * given name, and attaches tightened Lee-bounds results to it. The bounds
  * are computed on the reduced form (y on win_in_batch); only the
  * confidence-interval end points are divided by the first-stage
  * coefficient, while tight_lb and tight_ub are stored unscaled.
  *
  * Positional arguments:
  *   name        name used by eststo / est restore
  *   y           outcome variable
  *   x           endogenous treatment indicator (entered as i.x)
  *   instrument  instrument(s) for x
  *   c           control variables
  *   precomma    if-condition of the regressions
  *   absorb      option text placed right after the comma
  *   cluster     option text placed after absorb
  *   varcount    suffix of the global bounds_varcount written at the end
  *
  * Hidden inputs: the first stage, reduced form and bounds are hard-coded
  * to the variable win_in_batch; the bounds part also reads age, ed_cat,
  * gender, test_score_bins, stratum_win_prob, batch and the globals
  * demog_batch and vce.
  *
  * Results added to the stored estimates: control_mean, obs, lb_tight,
  * ub_tight, tight_lb, tight_ub, tight_lb_se, tight_ub_se, n_sq,
  * lb_tight_ci, ub_tight_ci.
  *
  * Side effects: selection, pred_outcome and pred_outcome_q are recreated
  * and left in the data, the data are re-sorted, and global scalars and
  * matrices (lb, ub, vlb, vub, weight, obs, ctst, ...) are overwritten.
  * Ties are broken with runiform(), so results are reproducible only if
  * the caller sets the seed.
  *
  * NOTE(review): the scalars lb, ub, weight, ... are referenced by bare
  * name; in Stata expressions a variable of the same name takes
  * precedence over a scalar. Confirm the data hold no such variables.
  cap program drop lee_iv_reg
  program define lee_iv_reg
    version 17.0
    args name y x instrument c precomma absorb cluster varcount

    * First, run IV regression
    di "Running: IV Regression for `y'"
    di "`c(current_time)'"

    * Run regression
    eststo `name' : ivreghdfe `y' (i.`x' = `instrument') `c' if `precomma', `absorb' `cluster'

    * Get control complier mean:
    * mean of y among x == 1 minus the estimated coefficient on 1.x
    local beta_iv = _b[1.`x']
    qui sum `y' if `precomma' & `x' == 1
    qui estadd scalar control_mean = round(`r(mean)' - `beta_iv', 0.001)

    * Save first stage
    * Pass the cluster argument: this program has no vce argument, so the
    * previously used vce local was always empty
    reghdfe `x' win_in_batch `c' if `precomma', `absorb' `cluster'
    local first_stage = _b[win_in_batch]

    * Now, save lee bounds
    * Reduced form; its estimation sample defines selection below
    reghdfe `y' win_in_batch `c' if `precomma', `absorb' `cluster'

    * selection = 1 for observations used in the reduced form
    cap drop selection
    gen selection = e(sample)

    * Store the reduced-form sample size as a 1x1 matrix. It is attached
    * to the stored IV results after est restore further down; adding it
    * here would attach it to the unstored reduced form and it would be lost
    matrix define obs = e(N)
    mat colnames obs = win_in_batch
    matrix list obs

    ** tightened leebounds
    * predicted outcomes: fit y on covariates within the selected sample,
    * then cut the prediction into quantile cells
    local max_quantile = 10
    cap drop pred_outcome*
    reghdfe `y' age i.ed_cat i.gender i.test_score_bins stratum_win_prob c.stratum_win_prob#i.batch if selection==1, $demog_batch $vce
    qui predict pred_outcome
    qui xtile pred_outcome_q =pred_outcome, nquantiles(`max_quantile')

    ** tightened leebounds
    * leebounds is run only to read e(cellsel); the bounds themselves are
    * computed by hand further down
    leebounds `y' win_in_batch, cieffect tight(pred_outcome_q) select(selection)

    * While leebounds reports cellsel == hetero, reduce the number of
    * quantile cells one at a time; fall back to a single cell if the
    * count reaches 2
    di "`e(cellsel)'"
    if "`e(cellsel)'" == "hetero" {
        local n = `max_quantile'
        while "`e(cellsel)'" == "hetero" & `n'>2 {
        local n = `n' - 1
        di "n : `n' "
        drop pred_outcome_q
        qui xtile pred_outcome_q =pred_outcome, nquantiles(`n')
        leebounds `y' win_in_batch, cieffect tight(pred_outcome_q) select(selection)
        di "`e(cellsel)'"
        }
        // cannot find the quantile
        if `n' <= 2 {
            drop pred_outcome_q
            gen pred_outcome_q = 1
        }
    }


    * Selection (response) rates by treatment arm
    sum selection if win_in_batch == 0
    local select_0 = `r(mean)'
    sum selection if win_in_batch == 1
    local select_1 = `r(mean)'

    * reverse if the control group has been surveyed further (lower attrition)
    * From here until the end of the program win_in_batch holds the flipped
    * indicator; the original is kept in win_in_batch_prev
    if `select_0' > `select_1' {
        clonevar win_in_batch_rev = win_in_batch
        replace win_in_batch_rev = 1 if win_in_batch==0
        replace win_in_batch_rev = 0 if win_in_batch==1
        rename win_in_batch win_in_batch_prev
        rename win_in_batch_rev win_in_batch
        }

    *Correct for "ecological fallacy"/weighting problem
    * If quantiles don't "agree" with aggregate, use untightened bounds
    * loop is set to 0 once one disagreeing cell has been found, after which
    * the remaining levels are skipped
    local loop = 1
    levelsof pred_outcome_q
    foreach level in `r(levels)' {
        if `loop' == 0 continue
        summ selection if win_in_batch == 1 & pred_outcome_q == `level'
        local treat_select = `r(mean)'
        summ selection if win_in_batch == 0 & pred_outcome_q == `level'

        if `r(mean)' <= `treat_select' continue
        else {
            *drop pred_outcome_q
            replace pred_outcome_q = 1
            local loop = 0
            di "`loop'"
        }
    }
    * # of always takers
    qui count if selection == 1 & win_in_batch == 0
    local always_takers = `r(N)'

    * calculate bounds within each cell
    * Accumulators for the weighted sums over cells
    scalar ub = 0
    scalar lb = 0
    scalar vlb = 0
    scalar vub = 0
    scalar weight = 0
    scalar ubttreat_sq = 0
    scalar lbttreat_sq = 0

    * num counts processed cells; ncat is the number of distinct cells
    local num = 0
    tab pred_outcome_q
    local ncat = r(r)

    levelsof pred_outcome_q

    foreach level in `r(levels)' {
        local num = `num' + 1

        * create trimmed vars
        gen outcome_t = `y' if selection==1 & win_in_batch == 1 & pred_outcome_q == `level'
        gen outcome_c = `y' if selection==1 & win_in_batch == 0 & pred_outcome_q == `level'

        * randomize ties
        * Sort on the full name rand_tie: the abbreviation rand fails with
        * varabbrev off and could silently match another variable
        gen rand_tie = runiform() if pred_outcome_q == `level'
        sort outcome_t rand_tie, stable
        * Missing outcome_t sorts last, so order is the rank among treated
        * selected observations of this cell and quantile is rank / count
        gen order = _n if outcome_t != .
        summ order
        gen quantile = order / `r(N)'

        * calculate trimming thresholds
        * q = 1 - (control selection rate / treated selection rate) in the cell
        summ selection if win_in_batch == 1 & pred_outcome_q == `level'
        local treat_select = `r(mean)'
        summ selection if win_in_batch == 0 & pred_outcome_q == `level'
        local q`level' = 1 - (`r(mean)' / `treat_select')

        * trimmed samples
        * Not reversed: upper bound keeps the top of the treated
        * distribution, lower bound keeps the bottom

        if `select_1' >= `select_0'{
            gen `y'_ub = outcome_t if quantile >= `q`level''
            replace `y'_ub = outcome_c if win_in_batch == 0

            gen `y'_lb = outcome_t if quantile <= (1 - `q`level'')
            replace `y'_lb = outcome_c if win_in_batch == 0
        }

        * Reversed arms: the two trimming directions are swapped
        if `select_0' > `select_1'{
            gen `y'_ub = outcome_t if quantile <= (1 - `q`level'')
            replace `y'_ub = outcome_c if win_in_batch == 0

            gen `y'_lb = outcome_t if quantile >= `q`level''
            replace `y'_lb = outcome_c if win_in_batch == 0
        }

        * Cell shares of selection x treatment within this quantile cell
        tab selection win_in_batch if pred_outcome_q == `level' , matcell(ctst)

        local nall = r(N)
        mat ctst = ctst/`nall'
        local est   = ctst[2,2]
        local esnt  = ctst[2,1]
        local et    = ctst[1,2]+ctst[2,2]
        local oddsc = ctst[1,1]/ctst[2,1]
        local oddst = ctst[1,2]/ctst[2,2]

        **** asymptotic variance
        *** lower bound
        local itrim = 100- 100*`q`level''
        _pctile `y' if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level', percentiles(`itrim')
        local lth = r(r1)

        sum `y'_lb if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level'
        local lub = r(mean)

        ** Analytic Variance
        * r(Var) below still refers to the summarize of the trimmed outcome
        local vp = (1-`q`level'')^2*(`oddst'/(`et')+`oddsc'/(1-`et'))
        di "`vp'"
        local vb1 = r(Var)/(`est'*(1-`q`level''))
        di "`vb1'"
        local vb2 = (`lth'-`lub')^2*(`q`level'')*(`est'*(1-`q`level''))^-1
        di "`vb2'"
        local vb3 = ((`lth'-`lub')/(1-`q`level''))^2*`vp'
        di "`vb3'"

        scalar vlb`level' = `vb1'+`vb2'+`vb3'

        di "vlb:"
        di vlb`level'

        *** upper bound
        local itrim = 100*`q`level''
        _pctile `y' if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level', percentiles(`itrim')
        local uth = r(r1)

        sum `y'_ub if selection == 1 & win_in_batch == 1 & pred_outcome_q == `level'
        local tub = r(mean)

        ** Analytic Variance
        local vp = (1-`q`level'')^2*(`oddst'/`et'+`oddsc'/(1-`et'))
        local vb1 = r(Var)/(`est'*(1-`q`level''))
        local vb2 = (`uth'-`tub')^2*(`q`level'')*(`est'*(1-`q`level''))^-1
        local vb3 = ((`uth'-`tub')/(1-`q`level''))^2*`vp'

        scalar vub`level' = `vb1'+`vb2'+`vb3'
        * Show the scalar's value (previously only its name was printed)
        di "VUB: " vub`level'

        * weight
        * share of all selected control observations that fall in this cell
        count if selection == 1 & win_in_batch == 0 & pred_outcome_q == `level'
        scalar weight`level' = `r(N)' / `always_takers'

        * regressions
        * The trimmed outcomes are non-missing only inside this cell, so
        * each regression is effectively restricted to the cell
        reghdfe `y'_lb win_in_batch stratum_win_prob c.stratum_win_prob#i.batch if selection==1 , $demog_batch  $vce
        scalar lb`level' = _b[win_in_batch]

        * Undo the sign change caused by the flipped treatment indicator
        if `select_0' > `select_1' {
            scalar lb`level' = - lb`level'
        }

        scalar lb = lb + (lb`level' * weight`level')

        reghdfe `y'_ub win_in_batch stratum_win_prob c.stratum_win_prob#i.batch if selection==1 , $demog_batch  $vce
        scalar ub`level' = _b[win_in_batch]

        if `select_0' > `select_1' {
            scalar ub`level' = - ub`level'
        }

        scalar ub = ub + (ub`level' * weight`level')

        di "weight"
        di weight`level'

        scalar weight = weight + weight`level'

        scalar vlb = vlb + (weight`level')*vlb`level'
        scalar vub = vub + (weight`level')*vub`level'

        di "treat_sq"
        scalar ubttreat_sq = ubttreat_sq + weight`level'*(ub`level'-`tub')^2
        scalar lbttreat_sq = lbttreat_sq + weight`level'*(lb`level'-`lub')^2

        di "vlb"
        di vlb
        di "ubttreat_sq"
        di ubttreat_sq

        * for the last one Chamberlain 1994: The asymptotic variance is the
        * sum of two components: 1) the (weighted) average of the asymptotic
        * variance for each group (Λ1 in Chamberlain (1994), 2) the
        * (weighted) average squared deviation of each group's estimate from
        * the "Total" mean (Λ2 in Chamberlain (1994)). "Total" refers to the
        * square root of the sum the squared components.
        * https://www.princeton.edu/~davidlee/wp/resrevision8.pdf

        if `num' == `ncat' {
            di "last ncat"
            di ubttreat_sq
            di lbttreat_sq
            di vlb
            di weight

            sum `y' if win_in_batch == 0 & selection == 1
            local vc = r(Var)/r(sum_w)

            scalar Vub_tot_sq = sqrt(vub + ubttreat_sq + `vc')
            scalar Vlb_tot_sq = sqrt(vlb + lbttreat_sq + `vc')

// 				scalar vub = (vub/(weight)^2+((ubttreat_sq/weight)-(ub/weight)^2))/`ntall'+`vc'
// 				scalar vlb = (vlb/(weight)^2+((lbttreat_sq/weight)-(lb/weight)^2))/`ntall'+`vc'
// 				scalar vlb_sq = sqrt(vlb)
// 				scalar vub_sq = sqrt(vub)
        }

        * drop
        drop outcome_* order rand_tie quantile `y'_lb `y'_ub

    } // end of loop through quantiles

    * Square root of the number of tabulated observations, used to scale
    * the standard errors in the confidence interval below
    tab selection win_in_batch, matcell(ctst)
    local nall = r(N)
    scalar nall_sq = sqrt(`nall')

    di lb
    di ub

    matrix define lb_tight = lb
    matrix define ub_tight = ub

    mat colnames lb_tight = win_in_batch
    mat colnames ub_tight = win_in_batch
    * Make the stored IV regression active again before adding results
    est restore `name'
    * Attach the sample-size matrix built after the reduced form
    estadd matrix obs
    estadd matrix lb_tight
    estadd matrix ub_tight

    estadd scalar tight_lb = lb
    estadd scalar tight_ub = ub

    estadd scalar tight_lb_se = Vlb_tot_sq
    estadd scalar tight_ub_se = Vub_tot_sq

    estadd scalar n_sq = nall_sq

    * CI end points: (bound -/+ 1.96 * se / sqrt(N)) divided by the
    * first-stage coefficient, stored rounded to 2 decimals
    local lb_tight_ci = (`e(tight_lb)' - 1.96*`e(tight_lb_se)'/`e(n_sq)')/`first_stage'
    local lb_tight_ci: di %4.2f `lb_tight_ci'
    estadd scalar lb_tight_ci `lb_tight_ci'

    local ub_tight_ci = (`e(tight_ub)' + 1.96*`e(tight_ub_se)'/`e(n_sq)')/`first_stage'
    local ub_tight_ci: di %4.2f `ub_tight_ci'
    estadd scalar ub_tight_ci `ub_tight_ci'

    global bounds_`varcount' = "(`lb_tight_ci', `ub_tight_ci')"
    di "varcount: `varcount'"
    di "bounds: ${bounds_`varcount'}"

    * reverse DV switch for higher control group response rate
    if `select_0' > `select_1' {
        replace win_in_batch = win_in_batch_prev
        drop win_in_batch_prev
        }

  end
// END
